# 模型库
import torch
import torchvision

# 进度条显示
from tqdm import tqdm

# 绘图库
import matplotlib.pyplot as plt

# 命令行参数获取
import utils.parameters

# 网络结构
from models.net import Net

if __name__ == "__main__":
    # 如果网络能在GPU中训练，就使用GPU；否则使用CPU进行训练
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # 定义 图像预处理器 对象
    # 将图片转换为张量(多维数组，存储浮点数)
    transform_func = torchvision.transforms.ToTensor()

    # 读取命令行参数
    params_parser = utils.parameters.get_train_args()

    # 设置批大小和训练轮数
    # 数据集的图片按照批大小分批，训练完所有批次后，算为一轮
    BATCH_SIZE = params_parser.batch_size
    EPOCHS = params_parser.epochs

    # 是否只是下载数据集
    IS_DOWNLOAD_DATASET = params_parser.download_dataset

    # 加载训练和测试数据
    train_dataset = torchvision.datasets.CIFAR10(
        "./data/", train=True, transform=transform_func, download=IS_DOWNLOAD_DATASET
    )
    test_dataset = torchvision.datasets.CIFAR10(
        "./data/", train=False, transform=transform_func, download=IS_DOWNLOAD_DATASET
    )

    if IS_DOWNLOAD_DATASET:
        print("数据集下载完成！")
        exit()

    # 建立数据迭代器，随机加载
    # 装载训练集对象
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True
    )
    # 装载测试集
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=True
    )

    # DataLoader容器为封装后的线性表，当成链表用即可
    # 每个元素是一个元组，包含一个批次的数据和标签，以张量形式存储
    # 共ceil(len/BATCH_SIZE)个元组，每个元组包含一个批次的数据和标签
    # 元素的数据结构如下：
    # ( 图像张量, 标签张量 )
    # 图像张量数据结构：
    # double[批大小][通道数][图片高度][图片宽度]
    # 标签张量数据结构：
    # double[批大小]

    # 是否只是查看数据的形状
    if params_parser.show_datashape:
        # 定义标签对应的字符串
        label_str = [
            "airplane",
            "automobile",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        ]
        # 获取第一个批次的数据及其对应的索引。
        # `example_data` 是图像张量，`example_targets` 是对应的标签
        (example_data, example_targets) = next(iter(train_loader))
        # 创建一个新的图形对象用于绘图
        fig = plt.figure()
        # 使用 for 循环绘制当前批次的前 16 个图像
        for i in range(16):
            # 在 4x4 的网格中创建子图，并指定当前子图的位置
            plt.subplot(4, 4, i + 1)
            # 调整子图之间的间距以避免重叠
            plt.tight_layout()
            # 显示图像
            plt.imshow(torchvision.transforms.ToPILImage()(example_data[i]))
            # 设置子图标题为该样本的真实标签
            plt.title("Label: {}".format(label_str[example_targets[i]]))
            # 隐藏 x 轴和 y 轴的刻度
            plt.xticks([])
            plt.yticks([])
        # 显示所有绘制的图像。
        plt.show()
        # 打印出当前批次数据的形状，了解数据的维度信息。
        print("数据形状(批次大小, 通道数, 图片高度, 图片宽度):", example_data.shape)
        exit()

    # 构建模型实例
    net = Net()
    # 判断是否继续训练
    if params_parser.continue_train_model:
        # 加载模型参数
        net.load_state_dict(
            torch.load(
                "./checkpoints/" + params_parser.continue_train_model_name,
                weights_only=True,
                map_location=torch.device(device),
            )
        )
    # 设置网络实例在设备上运行
    net = net.to(device)

    # 交叉熵损失函数
    loss_fun = torch.nn.CrossEntropyLoss()
    # Adam优化器
    optimizer = torch.optim.Adam(net.parameters())

    # 训练网络

    # 记录训练过程中的损失和准确率
    history = {"Test Loss": [], "Test Accuracy": []}
    # 训练循环
    for epoch in range(1, EPOCHS + 1):
        # 进度条对象，遍历一遍train_loader就是一轮(epoch)
        process_bar = tqdm(train_loader, unit="step")
        # 切换网络为训练模式
        net.train(True)
        # 每次循环都能训练一个批次，进度条将在循环中更新
        for train_imgs, labels in process_bar:
            # 将 图片张量 和 标签张量 设置到设备上
            train_imgs = train_imgs.to(device)
            labels = labels.to(device)
            # 清零梯度，防止累积
            net.zero_grad()
            # 前向传播一个批次的图像数据
            outputs = net(train_imgs)
            # 计算本批次的损失
            loss = loss_fun(outputs, labels)
            # 获取预测结果，dim=1表示取第2维的最大值，也就是选出可能性最大的那个类别
            predictions = torch.argmax(outputs, dim=1)
            # 计算准确率，将批次内预测正确的样本数除以总样本数
            accuracy = torch.true_divide(
                torch.sum(predictions == labels), labels.shape[0]
            )
            # 反向传播，计算梯度
            loss.backward()
            # 更新参数
            optimizer.step()

            # 更新进度条显示
            # 显示当前批次的损失和准确率
            process_bar.set_description(
                "[%d/%d] Loss: %.4f, Acc: %.4f, Progress"
                % (epoch, EPOCHS, loss.item(), accuracy.item())
            )

        # 训练完最后一个批次后，使用测试集评估模型效果
        correct, total_loss = 0, 0
        # 切换网络为推理模式
        net.train(False)
        # 在作用域内不计算梯度，节省内存
        with torch.no_grad():
            process_bar = tqdm(test_loader, unit="step")
            # 执行一轮测试
            for i, (test_imgs, labels) in enumerate(process_bar):
                # 将 图片张量 和 标签张量 设置到设备上
                test_imgs = test_imgs.to(device)
                labels = labels.to(device)
                # 前向传播一个批次的图像数据
                outputs = net(test_imgs)
                loss = loss_fun(outputs, labels)
                predictions = torch.argmax(outputs, dim=1)
                # 求损失
                total_loss += loss
                # 累加正确个数
                correct += torch.sum(predictions == labels)
                # 显示进度
                process_bar.set_description("Testing, Progress")

            # 计算测试集的准确率，正确个数除以总样本数
            test_accuracy = torch.true_divide(correct, (BATCH_SIZE * len(test_loader)))
            # 计算每个批次的平均损失
            test_loss = torch.true_divide(total_loss, len(test_loader))
            # 加入训练过程中的损失和准确率列表中
            history["Test Loss"].append(test_loss.item())
            history["Test Accuracy"].append(test_accuracy.item())

        # 显示最终结果
        print(
            "Epoch[%d] Loss: %.4f, Acc: %.4f, Test Loss: %.4f, Test Acc: %.4f"
            % (
                epoch,
                loss.item(),
                accuracy.item(),
                test_loss.item(),
                test_accuracy.item(),
            )
        )

    # 是否静默保存模型
    if params_parser.quiet_save_model:
        print("Test Loss:", history["Test Loss"])
        print("Test Accuracy:", history["Test Accuracy"])
        torch.save(net.state_dict(), "./checkpoints/" + params_parser.save_model_name)
        exit()

    # 绘制训练过程中的损失和准确率曲线
    # 对测试Loss进行可视化
    plt.plot(history["Test Loss"], label="Test Loss")
    plt.legend(loc="best")
    plt.grid(True)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()

    # 对测试准确率进行可视化
    plt.plot(history["Test Accuracy"], color="red", label="Test Accuracy")
    plt.legend(loc="best")
    plt.grid(True)
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.show()

    print("save model?(y/n)")
    is_save_model = input()
    if is_save_model == "y":
        print("saving model...")
        torch.save(net.state_dict(), "./checkpoints/" + params_parser.save_model_name)
        print("save model at:", "/checkpoints/" + params_parser.save_model_name)
